import numpy
import struct
import sys
import os
def align_file(f, n):
    """Advance ``f`` to the next multiple of ``n`` bytes by consuming padding.

    ``n`` must be a power of two because the rounding is done with a bit mask.
    Padding is consumed with ``read`` (not ``seek``), so near the end of the
    stream the position simply stops at EOF.
    """
    here = f.tell()
    # ~(n - 1) == -n for Python ints: round ``here`` up to a multiple of n.
    boundary = (here + n - 1) & -n
    f.read(boundary - here)
def read_traj_frame(f, v=True):
    """Read one frame from an ESMD ``.estraj`` stream.

    Layout as parsed below (native byte order): a 16-byte header of two
    uint64 values (step index, number of cells), then for each cell:

    * uint64 atom count ``natom``;
    * three float64 cell-origin offsets;
    * ``natom`` int64 atom tags;
    * ``natom`` int32 per-atom values (consumed but not returned), followed by
      4 padding bytes when ``natom`` is odd;
    * ``natom`` x 3 float64 positions relative to the cell origin;
    * ``natom`` x 3 float64 velocities, read only when ``v`` is true.

    Args:
        f: binary file object positioned at the start of a frame.
        v: whether to read the per-atom velocity arrays after the positions.
            Presumably this must match how the file was written — TODO confirm.

    Returns:
        ``False`` if fewer than 16 bytes remain (end of stream).
        Otherwise ``(istep, tags, x, vel)``:

        * ``tags`` is a 1-D int64 array;
        * ``x`` holds ``(N, 3)`` float64 absolute positions;
        * ``vel`` holds ``(N, 3)`` float64 velocities, or ``None`` when ``v``
          is false.
    """
    header = f.read(16)
    if len(header) < 16:
        return False
    istep, ncells = struct.unpack('QQ', header)
    celltags = []
    cellx = []
    cellv = []
    for _ in range(ncells):
        [natom] = struct.unpack('Q', f.read(8))
        origin = numpy.asarray(struct.unpack('ddd', f.read(24)))
        tags = numpy.frombuffer(f.read(natom * 8), dtype=numpy.int64)
        # Per-atom int32 values: consumed to advance the stream, not returned.
        f.read(natom * 4)
        if natom & 1:
            # An odd count of int32 values leaves the stream 4 bytes short of
            # 8-byte alignment for the float64 arrays that follow.
            f.read(4)
        xs = numpy.frombuffer(f.read(natom * 24), dtype=numpy.double).reshape([-1, 3]) + origin
        celltags.append(tags)
        cellx.append(xs)
        if v:
            cellv.append(numpy.frombuffer(f.read(natom * 24), dtype=numpy.double).reshape([-1, 3]))

    if ncells == 0:
        # numpy.concatenate rejects an empty list; a frame without cells has no atoms.
        empty_x = numpy.empty((0, 3))
        empty_v = numpy.empty((0, 3)) if v else None
        return istep, numpy.empty(0, dtype=numpy.int64), empty_x, empty_v

    vel = numpy.concatenate(cellv) if v else None
    return istep, numpy.concatenate(celltags), numpy.concatenate(cellx), vel
def merge_traj_frames(frames):
    """Merge per-shard pieces of one step into global, tag-indexed arrays.

    Args:
        frames: sequence of ``(istep, tags, x, v)`` tuples as returned by
            ``read_traj_frame``. Rows of ``x`` and ``v`` correspond
            one-to-one with ``tags``.

    Returns:
        ``(istep, x, v)``, where ``x`` and ``v`` are ``(natom, 3)`` float
        arrays with ``natom = max(tag) + 1``. Row ``t`` holds the data of the
        atom tagged ``t``; rows for tags absent from every piece stay zero.

    Exits the process with status 1 if the pieces disagree on ``istep``.
    """
    isteps = numpy.asarray([fm[0] for fm in frames])
    tags = numpy.concatenate([fm[1] for fm in frames])
    xoo = numpy.concatenate([fm[2] for fm in frames])
    voo = numpy.concatenate([fm[3] for fm in frames])
    if isteps.max() != isteps.min():
        print("istep misalign detected, exiting", file=sys.stderr)
        sys.exit(1)
    # max() is undefined on an empty array; a step with no atoms merges to zero rows.
    natom = int(tags.max()) + 1 if tags.size else 0
    x = numpy.zeros((natom, 3))
    x[tags] = xoo
    v = numpy.zeros((natom, 3))
    v[tags] = voo
    return isteps[0], x, v

def read_frames(trjs, offset=None):
    """Yield merged frames read in lockstep from several ``.estraj`` shards.

    Before each frame, every shard is advanced to its next 4096-byte boundary.
    One frame is then read from each shard, and the pieces are merged by atom
    tag.

    Args:
        trjs: list of binary file objects, one per trajectory shard.
        offset: length-3 vector added to every position; defaults to zero.

    Yields:
        ``(x, v)``: the ``(natom, 3)`` positions shifted by ``offset`` and the
        matching velocities (not shifted).

    Iteration stops as soon as any shard runs out of frames. A warning is
    printed to stderr if the shards end at different frames.
    """
    if offset is None:
        offset = numpy.zeros(3)
    if not trjs:
        return
    while True:
        for f in trjs:
            align_file(f, 4096)
        frms = [read_traj_frame(f) for f in trjs]
        if not all(frms):
            if any(frms):
                print("trajectory files end at different frames, stopping", file=sys.stderr)
            return
        _istep, x, v = merge_traj_frames(frms)
        # The offset is a coordinate translation; velocities are unaffected by it.
        yield x + offset, v

import ctypes
class DCDHeader(ctypes.Structure):
    """Fixed-size DCD file header, laid out exactly as it is written to disk.

    It consists of three unformatted records, each bracketed by a 4-byte
    length marker (the ``nb*_begin`` / ``nb*_end`` pairs):

    * the 84-byte "head" record: "CORD" plus 20 ints/floats;
    * the 164-byte "title" record: a count plus two 80-char strings;
    * the 4-byte "natoms" record.

    Every field is 4-byte sized or a char array, so ctypes inserts no padding
    and the structure is 276 bytes. ctypes uses native byte order.
    Field order and types ARE the file format; do not reorder.
    """
    _fields_ = [
        ("nbhead_begin"  , ctypes.c_int),     #/* = 84 (bytes for "head" section) */
        ("title"         , ctypes.c_char*4),  #/* = "CORD" */
        ("numframes"     , ctypes.c_int),     #/*  number of frames, updated as we write */
        ("firststep"     , ctypes.c_int),     #/*  first step number */
        ("framestepcnt"  , ctypes.c_int),     #/*  number of steps per frame */
        ("numsteps"      , ctypes.c_int),     #/*  total number sim steps + first step, updated */
        ("zero5"         , ctypes.c_int*5),   #/* = 0 */
        ("timestep"      , ctypes.c_float),   #/*  time step */
        ("iscell"        , ctypes.c_int),     #/* = 1 if unitcell info, = 0 if not */
        ("zero8"         , ctypes.c_int*8),   #/* = 0 */
        ("charmmversion" , ctypes.c_int),     #/* = 24 (CHARMM 24 format DCD file) */
        ("nbhead_end"    , ctypes.c_int),     #/* = 84 (bytes for "head" section) */
        ("nbtitle_begin" , ctypes.c_int),     #/* = 164 (bytes for "title" section) */
        ("numtitle"      , ctypes.c_int),     #/* = 2 (i.e. two 80-char title strings) */
        ("title_str"     , ctypes.c_char*80), #/*  remarks title string */
        ("create_str"    , ctypes.c_char*80), #/*  remarks create time string */
        ("nbtitle_end"   , ctypes.c_int),     #/* = 164 (bytes for "title" section) */
        ("nbnatoms_begin", ctypes.c_int),     #/* = 4 (bytes for "natoms" section) */
        ("natoms"        , ctypes.c_int),     #/*  number of atoms */
        ("nbnatoms_end"  , ctypes.c_int),     #/* = 4 (bytes for "natoms" section) */
    ]
class DCDCell(ctypes.Structure):
    """Per-frame unit-cell record: six float64 values between 4-byte length markers.

    ``dummy1`` places ``unitcell`` at byte offset 8 so that the doubles are
    8-byte aligned in the struct, and ``dummy2`` rounds the size up to 64.
    Only bytes 4..60 (``nbcell_begin`` through ``nbcell_end``, 56 bytes) are
    meant to be written to the file; the padding words are not part of the
    on-disk record.
    """
    _fields_ = [
        ("dummy1",       ctypes.c_int      ), #/*  (padding to 8-byte word alignment) */
        ("nbcell_begin", ctypes.c_int      ), #/* = 48 (bytes for "unitcell" section) */
        ("unitcell",     ctypes.c_double*6 ), #/*  describes periodic cell */
        ("nbcell_end",   ctypes.c_int      ), #/* = 48 (bytes for "unitcell" section) */
        ("dummy2",       ctypes.c_int      ), #/*  (padding to 8-byte word alignment) */
    ]

def DCD_hdr_init():
    """Return a new ``DCDHeader`` filled with this script's default values.

    The record-length markers match the three header records (84-byte head,
    164-byte title block, 4-byte atom count). Frame and atom counts start at
    zero, and ``iscell`` is 1 to announce a unit-cell record per frame.
    The ``zero5`` / ``zero8`` arrays keep ctypes' zero initialisation.
    """
    initial = (
        ("nbhead_begin", 84),
        ("title", "CORD".encode()),
        ("numframes", 0),
        ("firststep", 0),
        ("framestepcnt", 1),
        ("numsteps", 1),
        ("timestep", 0),
        ("iscell", 1),
        ("charmmversion", 24),
        ("nbhead_end", 84),
        ("nbtitle_begin", 164),
        ("numtitle", 2),
        ("title_str", "REMARKS ESMD trajectory file".encode()),
        ("create_str", "REMARKS MDIO DCD file created 28".encode()),
        ("nbtitle_end", 164),
        ("nbnatoms_begin", 4),
        ("natoms", 0),
        ("nbnatoms_end", 4),
    )
    header = DCDHeader()
    for field, value in initial:
        setattr(header, field, value)
    return header
def parse_esmd_conf(path):
    """Parse an ESMD ``key = value`` configuration file into a dict.

    Lines that start with ``#`` or do not contain exactly one ``=`` are
    ignored. Keys are stripped and lower-cased; values are stripped.

    Three entries are converted:

    * ``cell``: whitespace-separated floats, stored as a list;
    * ``dt``: float;
    * ``trajfreq``: int.

    All other values remain strings.

    Raises:
        KeyError: if ``cell``, ``dt`` or ``trajfreq`` is missing.
        ValueError: if one of those entries cannot be converted.
    """
    conf = {}
    with open(path) as fh:
        for line in fh:
            if line.startswith('#'):
                continue
            parts = line.split('=')
            if len(parts) != 2:
                continue
            key, value = (p.strip() for p in parts)
            conf[key.lower()] = value
    # split() rather than split(' '): repeated spaces/tabs must not yield empty fields.
    conf["cell"] = [float(c) for c in conf["cell"].split()]
    conf["dt"] = float(conf["dt"])
    conf["trajfreq"] = int(conf["trajfreq"])
    return conf

def find_traj_files(prefix):
    """Open the consecutively numbered shards ``<prefix>-000000.estraj``, ``-000001``, ...

    Scanning stops at the first missing index. Each shard is opened in binary
    mode and reported on stderr. The caller owns the returned file objects.
    """
    handles = []
    index = 0
    while True:
        name = "%s-%06d.estraj" % (prefix, index)
        if not os.path.exists(name):
            break
        handles.append(open(name, "rb"))
        print("Found trajectory file %s" % name, file=sys.stderr)
        index += 1
    return handles

def init_dcd_file(conf):
    """Create ``<conf["trajectory"]>.dcd`` and write a provisional header to it.

    The header is built by ``DCD_hdr_init`` with ``timestep`` taken from
    ``conf["dt"]``. Its counters are still zero here: in this script,
    ``append_frame`` updates them and ``close_dcd`` rewrites the header.

    Returns:
        ``(hdr, dcdfile)``: the header object and the open binary file.
    """
    out = open("%s.dcd" % conf["trajectory"], "wb")
    header = DCD_hdr_init()
    header.timestep = conf["dt"]
    out.write(header)
    return header, out

def append_frame(hdr, dcdfile, x):
    """Append one frame of coordinates to ``dcdfile`` and bump the header counters.

    Each axis (X, then Y, then Z) is written as its own record:

    * an int32 length marker of ``4 * len(x)``;
    * the float32 values of that column of ``x``;
    * the same length marker again.

    ``hdr.numframes`` and ``hdr.numsteps`` are incremented, and
    ``hdr.natoms`` is set to ``len(x)``.
    """
    natom = len(x)
    marker = bytes(ctypes.c_int32(4 * natom))
    for axis in range(3):
        dcdfile.write(marker)
        dcdfile.write(x[:, axis].astype(numpy.float32).tobytes())
        dcdfile.write(marker)

    hdr.numframes += 1
    hdr.natoms = natom
    hdr.numsteps += 1
def close_dcd(hdr, dcdfile):
    """Overwrite the start of ``dcdfile`` with ``hdr``, then close the file.

    This replaces the provisional header written when the file was created
    with one carrying the final frame/atom counts.
    """
    dcdfile.seek(0)
    dcdfile.write(hdr)
    dcdfile.close()
# ---- Script entry point: trajconv.py <conffile> ----
# Converts the ESMD trajectory shards named by the config into a single DCD file.
if len(sys.argv) <= 1:
    print("Usage: trajconv.py conffile", file=sys.stderr)
    sys.exit(1)

conf = parse_esmd_conf(sys.argv[1])
# Expose the config directory and the launch directory as $CONFDIR / $WORKDIR,
# so that conf["root"] can reference them through expandvars below.
os.environ["CONFDIR"] = os.path.dirname(os.path.realpath(sys.argv[1]))
# PWD is only guaranteed when launched from a shell; fall back to the real cwd.
os.environ["WORKDIR"] = os.environ.get("PWD", os.getcwd())
root = os.path.expandvars(conf["root"])
os.chdir(root)
if root != os.environ["WORKDIR"]:
    print("Switched working directory to %s" % root, file=sys.stderr)

# Locate shards before creating the output so a bad prefix leaves no empty .dcd behind.
trjs = find_traj_files(conf["trajectory"])
if not trjs:
    print("No trajectory files %s-NNNNNN.estraj found" % conf["trajectory"], file=sys.stderr)
    sys.exit(1)

hdr, dcdfile = init_dcd_file(conf)
# conf["cell"] is used as (xlo, ylo, zlo, xhi, yhi, zhi): positions are
# shifted by -lo, and the edge lengths are hi - lo.
offset = numpy.asarray([-conf["cell"][0], -conf["cell"][1], -conf["cell"][2]])

cell = DCDCell()
cell.nbcell_begin = 48
cell.unitcell[0] = conf["cell"][3] - conf["cell"][0]
cell.unitcell[1] = 0
cell.unitcell[2] = conf["cell"][4] - conf["cell"][1]
cell.unitcell[3] = 0
cell.unitcell[4] = 0
cell.unitcell[5] = conf["cell"][5] - conf["cell"][2]
cell.dummy2 = 0
cell.dummy1 = 0
cell.nbcell_end = 48
print("Unit cell from conf file:", file=sys.stderr)
# Print unitcell[0..5] as a lower triangle: [0] / [1] [2] / [3] [4] [5].
for i in range(3):
    for j in range(i*(i+1)//2, i*(i+1)//2+i+1):
        print("%9.3f" % (cell.unitcell[j]), end="", file=sys.stderr)
    print("", file=sys.stderr)

for x, v in read_frames(trjs, offset):
    # Bytes 4..60 of DCDCell: marker + 6 doubles + marker (padding words dropped).
    dcdfile.write(bytes(cell)[4:60])
    append_frame(hdr, dcdfile, x)
    print("\rWrote %10d frames" % hdr.numframes, end="", file=sys.stderr)
print(file=sys.stderr)
for trj in trjs:
    trj.close()
close_dcd(hdr, dcdfile)
print("Trajectory file written to %s.dcd" % os.path.relpath(conf["trajectory"], os.environ["WORKDIR"]), file=sys.stderr)
